{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import openai\n",
    " \n",
    "openai.api_key = os.environ[\"OPENAI_API_KEY\"]\n",
    "\n",
    "def llm(prompt, stop=[\"\\n\"]):\n",
    "    response = openai.Completion.create(\n",
    "      model=\"text-davinci-002\",\n",
    "      prompt=prompt,\n",
    "      temperature=0,\n",
    "      max_tokens=100,\n",
    "      top_p=1,\n",
    "      frequency_penalty=0.0,\n",
    "      presence_penalty=0.0,\n",
    "      stop=stop\n",
    "    )\n",
    "    return response[\"choices\"][0][\"text\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import wikienv, wrappers\n",
    "env = wikienv.WikiEnv()\n",
    "env = wrappers.HotPotQAWrapper(env, split=\"dev\")\n",
    "env = wrappers.LoggingWrapper(env)\n",
    "\n",
    "def step(env, action):\n",
    "    attempts = 0\n",
    "    while attempts < 10:\n",
    "        try:\n",
    "            return env.step(action)\n",
    "        except requests.exceptions.Timeout:\n",
    "            attempts += 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ReAct"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import sys\n",
    "\n",
    "folder = './prompts/'\n",
    "prompt_file = 'prompts_naive.json'\n",
    "with open(folder + prompt_file, 'r') as f:\n",
    "    prompt_dict = json.load(f)\n",
    "\n",
    "webthink_examples = prompt_dict['webthink_simple6']\n",
    "instruction = \"\"\"Solve a question answering task with interleaving Thought, Action, Observation steps. Thought can reason about the current situation, and Action can be three types: \n",
    "(1) Search[entity], which searches the exact entity on Wikipedia and returns the first paragraph if it exists. If not, it will return some similar entities to search.\n",
    "(2) Lookup[keyword], which returns the next sentence containing keyword in the current passage.\n",
    "(3) Finish[answer], which returns the answer and finishes the task.\n",
    "Here are some examples.\n",
    "\"\"\"\n",
    "webthink_prompt = instruction + webthink_examples\n",
    "\n",
    "def webthink(idx=None, prompt=webthink_prompt, to_print=True):\n",
    "    question = env.reset(idx=idx)\n",
    "    if to_print:\n",
    "        print(idx, question)\n",
    "    prompt += question + \"\\n\"\n",
    "    n_calls, n_badcalls = 0, 0\n",
    "    for i in range(1, 8):\n",
    "        n_calls += 1\n",
    "        thought_action = llm(prompt + f\"Thought {i}:\", stop=[f\"\\nObservation {i}:\"])\n",
    "        try:\n",
    "            thought, action = thought_action.strip().split(f\"\\nAction {i}: \")\n",
    "        except:\n",
    "            print('ohh...', thought_action)\n",
    "            n_badcalls += 1\n",
    "            n_calls += 1\n",
    "            thought = thought_action.strip().split('\\n')[0]\n",
    "            action = llm(prompt + f\"Thought {i}: {thought}\\nAction {i}:\", stop=[f\"\\n\"]).strip()\n",
    "        obs, r, done, info = step(env, action[0].lower() + action[1:])\n",
    "        obs = obs.replace('\\\\n', '')\n",
    "        step_str = f\"Thought {i}: {thought}\\nAction {i}: {action}\\nObservation {i}: {obs}\\n\"\n",
    "        prompt += step_str\n",
    "        if to_print:\n",
    "            print(step_str)\n",
    "        if done:\n",
    "            break\n",
    "    if not done:\n",
    "        obs, r, done, info = step(env, \"finish[]\")\n",
    "    if to_print:\n",
    "        print(info, '\\n')\n",
    "    info.update({'n_calls': n_calls, 'n_badcalls': n_badcalls, 'traj': prompt})\n",
    "    return r, info"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import time\n",
    "idxs = list(range(7405))\n",
    "random.Random(233).shuffle(idxs)\n",
    "\n",
    "rs = []\n",
    "infos = []\n",
    "old_time = time.time()\n",
    "for i in idxs[:500]:\n",
    "    r, info = webthink(i, to_print=True)\n",
    "    rs.append(info['em'])\n",
    "    infos.append(info)\n",
    "    print(sum(rs), len(rs), sum(rs) / len(rs), (time.time() - old_time) / len(rs))\n",
    "    print('-----------')\n",
    "    print()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
